"""
Dataset Reconstruction Script

Description:
    This script reconstructs the dataset by extracting split zip files 
    (e.g., images_part_*.zip) and label archives into a unified directory structure.
    It automatically handles merging multi-part image archives and organizing
    detection and segmentation labels.

Dependencies:
    This script requires the 'tqdm' library for progress bars. 
    Standard libraries used: os, zipfile, glob.

Installation:
    Run the following command in your terminal to install the required library:
    pip install tqdm
"""

import os
import zipfile
import glob
from tqdm import tqdm

# --- CONFIGURATION ---
# Both paths are relative, so they resolve against the current working
# directory the script is launched from.
DOWNLOAD_DIR = "./downloaded_dataset"  # directory containing the downloaded .zip files (label zips and images_part_*.zip)
OUTPUT_DIR = "./restored_dataset" # root directory the archives are extracted into

def unzip_labels(zip_path, extract_to):
    """Extract one label archive into its dedicated folder.

    A missing or unreadable archive is reported on stdout and skipped, so a
    problem with a single label archive does not abort the rest of the
    reconstruction (the same policy ``unzip_image_parts`` applies to bad
    image parts).

    Args:
        zip_path: Path of the label ``.zip`` file to extract.
        extract_to: Directory that receives the archive contents. It is
            created if it does not exist yet.

    Returns:
        None.
    """
    if not os.path.exists(zip_path):
        print(f"Warning: {zip_path} not found. Skipping.")
        return

    print(f"Extracting {os.path.basename(zip_path)}...")
    os.makedirs(extract_to, exist_ok=True)

    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            # Extract member by member so tqdm can show per-file progress.
            for member in tqdm(zf.infolist(), desc="Extracting labels"):
                zf.extract(member, extract_to)
    except zipfile.BadZipFile:
        # Raised when the file is not a zip at all, and also when a member
        # fails its CRC check while being extracted.
        print(f"Error: {zip_path} is corrupted or not a valid zip file.")

def unzip_image_parts(source_dir, extract_to):
    """Extract every ``images_part_*.zip`` archive into one ``images`` folder.

    Parts are processed in sorted filename order. A part that cannot be read
    is reported and skipped so the remaining parts are still extracted; the
    final status line states whether any part failed instead of always
    claiming success.

    Args:
        source_dir: Directory searched for ``images_part_*.zip`` files.
        extract_to: Root output directory. Images end up in its ``images``
            subfolder, which is created if needed.

    Returns:
        None.
    """
    # Find all parts like images_part_001.zip, images_part_002.zip...
    part_files = sorted(glob.glob(os.path.join(source_dir, "images_part_*.zip")))

    if not part_files:
        print("No image zip parts found (images_part_*.zip).")
        return

    print(f"Found {len(part_files)} image zip parts. Beginning extraction...")

    # Target directory for images
    images_output_dir = os.path.join(extract_to, "images")
    os.makedirs(images_output_dir, exist_ok=True)

    failed_parts = []
    for zip_file in part_files:
        print(f"   -> Extracting {os.path.basename(zip_file)}...")
        try:
            with zipfile.ZipFile(zip_file, 'r') as zf:
                # Read the archive index once and reuse it for both the
                # layout check and the extraction loop.
                members = zf.infolist()

                # The first entry decides the layout: archives whose paths
                # already start with "images/" are extracted into the parent
                # output dir, flat archives go into the images subfolder.
                if members and members[0].filename.startswith("images/"):
                    target_path = extract_to
                else:
                    target_path = images_output_dir

                for member in tqdm(members, desc="Unzipping chunk", leave=False):
                    zf.extract(member, target_path)

        except zipfile.BadZipFile:
            print(f"Error: {zip_file} is corrupted or not a valid zip file.")
            failed_parts.append(zip_file)
        except Exception as e:
            # Deliberately broad: one unreadable part must not stop the
            # remaining parts from being extracted. The error is reported
            # and the part is recorded as failed.
            print(f"Error processing {zip_file}: {e}")
            failed_parts.append(zip_file)

    if failed_parts:
        print(
            f"Image extraction finished with errors: "
            f"{len(failed_parts)} of {len(part_files)} part(s) failed."
        )
        for failed in failed_parts:
            print(f"   -> Failed: {os.path.basename(failed)}")
    else:
        print("Image extraction complete.")

def verify_counts(output_dir):
    """Print a summary of how many files ended up in each output folder.

    Counts ``*.png`` files in ``images`` and ``*.txt`` files in
    ``detection_labels`` and ``segmentation_labels`` directly under
    *output_dir*, then prints the totals. A warning is added when no
    images were found.
    """
    def _count(subfolder, pattern):
        # Number of direct children of the subfolder matching the pattern.
        return len(glob.glob(os.path.join(output_dir, subfolder, pattern)))

    n_images = _count("images", "*.png")
    n_detection = _count("detection_labels", "*.txt")
    n_segmentation = _count("segmentation_labels", "*.txt")

    report_lines = [
        "\n--- Extraction Summary ---",
        f"Location: {os.path.abspath(output_dir)}",
        f" Images: {n_images}",
        f" Detection Labels: {n_detection}",
        f" Segmentation Labels: {n_segmentation}",
    ]
    print("\n".join(report_lines))

    if not n_images:
        print("\nWARNING: No images were found. Did you download all the parts?")

if __name__ == "__main__":
    print("--- Dataset Reconstruction Script ---")

    # Steps 1 and 2: extract the detection labels, then the segmentation
    # labels. Each archive goes into a folder of its own under OUTPUT_DIR.
    label_archives = (
        ("detection_labels.zip", "detection_labels"),
        ("segmentation_labels.zip", "segmentation_labels"),
    )
    for archive_name, folder_name in label_archives:
        unzip_labels(
            os.path.join(DOWNLOAD_DIR, archive_name),
            os.path.join(OUTPUT_DIR, folder_name),
        )

    # Step 3: merge all image parts into the shared images folder.
    unzip_image_parts(DOWNLOAD_DIR, OUTPUT_DIR)

    # Step 4: print how many files were extracted.
    verify_counts(OUTPUT_DIR)